"""
Filename: dynamics_lumpsum.jl
Paper: Unemployment and the Distribution of Liquidity
Authors: Zach Bethune and Guillaume Rocheteau
Contact: bethune@rice.edu
Last Modified: 8/22/23
Purpose: functions to compute transitional dynamics
"""

###############################################################################
# Blocks of DAG (see appendix B)
###############################################################################
# Fiscal block: lump-sum transfer path implied by the government budget
# constraint, computed period by period as
#   taut[1] = Agt[1]/Rft[1] - mm.Ag     + phimt[1]*(Mt[1] - M_1)
#   taut[t] = Agt[t]/Rft[t] - Agt[t-1]  + phimt[t]*(Mt[t] - Mt[t-1]),  t >= 2
# i.e. mm.Ag and M_1 play the role of the period-0 values of Agt and Mt.
# Returns a T×2 matrix whose second column is an exact copy of the first.
# Rmt is accepted to keep the block signatures uniform but is not used.
function fiscal_block(Rft,Rmt,phimt,Agt,Mt;mm::model_outcomes,M_1=1.0)
    nper = length(Rft)
    transfers = zeros((nper,2))

    # first period: lagged stocks come from mm.Ag and M_1
    transfers[1,1] = (1.0/Rft[1])*Agt[1] - mm.Ag + (phimt[1])*(Mt[1]-M_1)

    # remaining periods: lagged stocks come from the paths themselves
    for t in 2:nper
        transfers[t,1] = (1.0/Rft[t])*Agt[t] - Agt[t-1] + phimt[t]*(Mt[t]-Mt[t-1])
    end

    transfers[:,2] = transfers[:,1]
    return transfers
end

# Wage block: returns wagest of size (2, Nz, T).
#   * period 1 copies mm.wages_bar (wages of the initial steady state);
#   * period t >= 2 is set from pyt[t-1]: with Ys = κ_prime_inv(pyt[t-1]),
#       wagest[1,i,t] = LABOR_SHARE * zgrid[i]*(1 + pyt[t-1]*Ys - κ(Ys))
#       wagest[2,i,t] = REPLACE_RATE * wagest[1,i,t]
function wage_block(pyt; parms::parmsmodel,mm::model_outcomes,md::moments_data)
    nper = length(pyt)
    wagest = zeros(2,Nz,nper)

    # first period wage determined by the previous steady state
    wagest[:,:,1] = mm.wages_bar

    for t in 2:nper
        py_lag = pyt[t-1]
        ys = κ_prime_inv(py_lag;parms)
        revenue = 1.0 + py_lag*ys - κ(ys;parms)
        for iz in 1:Nz
            w_first = md.LABOR_SHARE*(mm.zgrid[iz]*revenue)
            wagest[1,iz,t] = w_first
            wagest[2,iz,t] = md.REPLACE_RATE*w_first
        end
    end
    return wagest
end

# Money-price block: solves phimt backward from a terminal anchor built from
# the new steady state, using phimt[t] = phimt[t+1]/Rmt[t] with
#   phimt[T+1] = mm_new.Zmd*mm_new.Rm/(Rmt[T]*Mt[T]).
# NOTE(review): Rmt[T] therefore enters phimt[T] twice (in the anchor and in the
# final division), while phim_jacobian uses -phimt[T]/Rmt[T] for
# ∂phimt[T]/∂Rmt[T], which corresponds to a single occurrence — confirm which
# one is intended.
function phim_block(Rmt,Mt;mm_new::model_outcomes)
    nper = length(Rmt)
    phimt = zeros(nper)
    running = (mm_new.Zmd*mm_new.Rm)/(Rmt[nper]*Mt[nper])
    for t in nper:-1:1
        running = running/Rmt[t]
        phimt[t] = running
    end
    return phimt
end

# Firm block.
# Given pyt and Rft (length T) returns
#   Yst    = κ_prime_inv.(pyt),
#   thetat = θ path (length T) solved backward from the terminal value
#            mm_new.theta via f(θ_t) = temp_t with
#            temp_t = (1-LABOR_SHARE)*rev_t/(Rft[t]*K)
#                     + (1-DELTA)*f(θ_{t+1})/Rft[t],
#            where rev_t = 1 + pyt[t]*Yst[t] - κ(Yst[t]). θ is floored at 1e-6,
#            and set to 1e-6 whenever temp_t < 1.
#   phift  = T×Nz matrix, phift[t,i] = zgrid[i]*K*Rft[t]*f(thetat[t]).
# Keywords `zgrid` and `labor_share` default to the GLOBAL `mm.zgrid` and
# `md.LABOR_SHARE`; this is what the block has always used, since `mm` and
# `md` are not arguments. Pass them explicitly to avoid relying on those globals.
function firms_block(pyt,Rft;parms::parmsmodel, mm_new::model_outcomes,
                     zgrid=mm.zgrid, labor_share=md.LABOR_SHARE)
    T=length(pyt)
    Yst = κ_prime_inv.(pyt;parms);
    frev = zgrid.*(1.0 .+ pyt.*Yst .- κ.(Yst;parms))';
    wage1 = labor_share.*frev;

    # backward recursion for θ, anchored at the new steady state
    thetat = zeros(T+1);
    thetat[T+1] = mm_new.theta
    for i=T:-1:1
        temp=(frev[1,i]-wage1[1,i])/(zgrid[1].*Rft[i]*parms.K) + ((1.0-parms.DELTA)*f(thetat[i+1];parms))/Rft[i]
        # presumably f_inv is only meaningful for arguments >= 1 — TODO confirm
        thetat[i] = temp>=1 ? max(f_inv(temp;parms),1e-6) : 1e-6
    end

    phift=zeros(T,Nz);
    phift[1:T,:] = (zgrid.*(parms.K.*Rft[1:T].*f.(thetat[1:T];parms))')'
    return Yst, thetat[1:T], phift
end

# Unemployment block: law of motion
#   u_t = (1 - λ(θ_t))*u_{t-1} + DELTA*(1 - u_{t-1}),   u_0 = 1 - mm.emp,
# returned for t = 1..T (u_0 itself is not included).
function unemp_block(thetat;parms::parmsmodel,mm::model_outcomes)
    nper = length(thetat)
    ut = zeros(nper)
    u_prev = 1.0-mm.emp
    for t in 1:nper
        u_prev = (1.0-λ(thetat[t];parms))*u_prev + parms.DELTA*(1.0-u_prev)
        ut[t] = u_prev
    end
    return ut
end

# Heterogeneous-agent (HA) block.
# Backward pass: starting from the terminal mm_new.Wa, call Wa_iter for
# i = T,...,1 with the period-i prices (pyt, Rft, Rmt), transfers taut, wages
# wagest and θ (thetat). This yields Wa and the policies am_p, a_p, c, ystar on
# the (Na,2,Nz) grid.
# Forward pass: starting from the distribution mm.g, aggregate the period-i
# policies, then push the distribution forward with the matrix that
# transition_g! stores in mm_trans.T.
# Second grid index: PI = [1-DELTA DELTA; λ(θ) 1-λ(θ)], so index 1 moves to
# index 2 with probability DELTA and index 2 moves to index 1 with probability
# λ(θ) (presumably 1 = employed, 2 = unemployed). Series with suffix "1" are
# conditional means over index 1; suffix "0" over index 2.
# Returns Ydt, Aft, Amt, a_p_mat, am_p_mat, c_mat, ystar_mat, Wa_mat, Ct, Ct1,
# Ct0, Aft1, Aft0, Amt1, Amt0, gt, with gt[1,:,:,:] = mm.g and gt[i+1,:,:,:]
# the distribution after period i.
function ha_block(pyt,Rft,Rmt,taut,wagest,thetat;parms::parmsmodel,mm::model_outcomes,mm_new::model_outcomes)
    T=length(pyt)

    #initialize classes and matrices
    mm_trans = deepcopy(mm); #placeholder for household and firm objects
    parms_trans = deepcopy(parms); #placeholder for parameters
    am_p_mat = zeros((T,Na,2,Nz));
    a_p_mat = zeros((T,Na,2,Nz));
    c_mat = zeros((T,Na,2,Nz));
    ystar_mat = zeros((T,Na,2,Nz));
    Wa_mat = zeros((T+1,Na,2,Nz));
    
    #solve household problem backwards in time
    #slot T+1 holds the new steady-state Wa; Wa_iter maps Wa at i+1 into Wa and policies at i
    Wa_mat[T+1,:,:,:] = copy(mm_new.Wa);
    for i=T:-1:1
        #load period-i inputs into the scratch copy that Wa_iter reads
        mm_trans.Rm = Rmt[i];
        mm_trans.py = pyt[i];
        mm_trans.Rf = Rft[i];
        mm_trans.tau[:,1,:] .= taut[i,1];
        mm_trans.tau[:,2,:] .= taut[i,2];
        mm_trans.wages_bar = wagest[:,:,i];
        mm_trans.theta = thetat[i];

        Wa_mat[i,:,:,:], am_p_mat[i,:,:,:], a_p_mat[i,:,:,:], c_mat[i,:,:,:], ystar_mat[i,:,:,:] = Wa_iter(; py=pyt[i],Rf=Rft[i],Wa_p=Wa_mat[i+1,:,:,:],agrid=mm_trans.agrid,parms=parms_trans,mm=mm_trans)
    end
    
    #solve for distributions and aggregates forward in time
    Ydt=zeros(T); Amt=zeros(T); Aft=zeros(T);
    Ct=zeros(T); 
    Amt0=zeros(T); Aft0=zeros(T); Ct0=zeros(T);
    Amt1=zeros(T); Aft1=zeros(T); Ct1=zeros(T);
    gt = zeros((T+1,Na,2,Nz));
    gt[1,:,:,:] = copy(mm.g);

    g=copy(mm.g)
    for i=1:T
        #aggregates: illiquid = a_p - am_p, liquid = am_p, consumption = c, weighted by g
        Aft[i]=sum(g.*(a_p_mat[i,:,:,:].-am_p_mat[i,:,:,:]));
        Amt[i]=sum(g.*am_p_mat[i,:,:,:]);
        Ct[i]=sum(g.*c_mat[i,:,:,:]);

        #conditional means over second index 2 ("0") and index 1 ("1")
        Aft0[i]=sum(g[:,2,:].*(a_p_mat[i,:,2,:].-am_p_mat[i,:,2,:]))./sum(g[:,2,:]);
        Aft1[i]=sum(g[:,1,:].*(a_p_mat[i,:,1,:].-am_p_mat[i,:,1,:]))/sum(g[:,1,:]);
        Amt0[i]=sum(g[:,2,:].*am_p_mat[i,:,2,:])/sum(g[:,2,:]);
        Amt1[i]=sum(g[:,1,:].*am_p_mat[i,:,1,:])/sum(g[:,1,:]);
        Ct0[i]=sum(g[:,2,:].*c_mat[i,:,2,:])/sum(g[:,2,:]);
        Ct1[i]=sum(g[:,1,:].*c_mat[i,:,1,:])/sum(g[:,1,:]);

        #period-i transition probabilities between the two states, using λ(thetat[i])
        mm_trans.PI = [(1.0-parms_trans.DELTA) parms_trans.DELTA; λ(thetat[i];parms=parms_trans) 1.0-λ(thetat[i];parms=parms_trans)]
        #Ydt: ymf/ym aggregated over current state k and next state kk with weights
        #PI[k,kk]*ALPHA*ALPHAmf and PI[k,kk]*ALPHA*(1-ALPHAmf) (compared with supply in H3)
        Ydt[i]=0.0
        for l=1:Nz
            for k=1:2
                for kk=1:2
                    #NOTE(review): ystar_int depends only on (kk,l); it is rebuilt for each k
                    ystar_int = LinearInterpolation(mm_trans.agrid, ystar_mat[i,:,kk,l], extrapolation_bc=Line());
                    for j=1:Na
                        Ydt[i] += mm_trans.PI[k,kk]*parms_trans.ALPHA*parms_trans.ALPHAmf*ymf(ystar_int,a_p_mat[i,j,k,l],am_p_mat[i,j,k,l],pyt[i])[1]*g[j,k,l]
                        Ydt[i] += mm_trans.PI[k,kk]*parms_trans.ALPHA*(1.0-parms_trans.ALPHAmf)*ym(ystar_int,a_p_mat[i,j,k,l],am_p_mat[i,j,k,l],pyt[i])[1]*g[j,k,l]
                    end
                end
            end
        end

        #build the period-i transition matrix from the period-i policies and move g forward
        mm_trans.am_p = copy(am_p_mat[i,:,:,:])
        mm_trans.a_p = copy(a_p_mat[i,:,:,:])
        mm_trans.ystar = copy(ystar_mat[i,:,:,:])
        mm_trans.c = copy(c_mat[i,:,:,:])
        transition_g!(mm_trans;py=pyt[i],parms=parms_trans); 
        #flatten g (column-major over (Na,2,Nz)), apply mm_trans.T, reshape back
        ghat = reshape(g,(Na*2*Nz,1));
        ghat_prime = mm_trans.T*ghat;
        g_prime = reshape(ghat_prime,(Na,2,Nz));
        gt[i+1,:,:,:]=copy(g_prime);
        g=deepcopy(g_prime)
    end

    return Ydt, Aft, Amt, a_p_mat, am_p_mat, c_mat, ystar_mat, Wa_mat, Ct, Ct1, Ct0, Aft1, Aft0, Amt1, Amt0, gt
end
    
# Evaluate every block of the DAG once, given guesses of the unknown paths
# (Rmt, Rft, pyt) and the exogenous paths (Agt, Mt).
# `mm` supplies the initial conditions and `mm_new` the terminal steady state.
# Returns (taut, wagest, phimt, Yst, thetat, phift, ut, Ydt, Aft, Amt, Ct, Ct1,
# Ct0, Aft1, Aft0, Amt1, Amt0, a_p_mat, am_p_mat, c_mat, ystar_mat, gt). The HA
# block's Wa_mat is computed but not returned.
function update_trans(Rmt,Rft,pyt,Agt,Mt;mm::model_outcomes,mm_new::model_outcomes,parms::parmsmodel,md::moments_data,M_1=1)
    # price of money, then the transfers it implies
    phimt = phim_block(Rmt, Mt; mm_new=mm_new)
    taut = fiscal_block(Rft, Rmt, phimt, Agt, Mt; mm=mm, M_1=M_1)

    # labour-market side: wages, firm objects and unemployment
    wagest = wage_block(pyt; parms=parms, mm=mm, md=md)
    Yst, thetat, phift = firms_block(pyt, Rft; parms=parms, mm_new=mm_new)
    ut = unemp_block(thetat; parms=parms, mm=mm)

    # households, fed with all of the above
    ha_out = ha_block(pyt, Rft, Rmt, taut, wagest, thetat; parms=parms, mm=mm, mm_new=mm_new)
    Ydt, Aft, Amt, a_p_mat, am_p_mat, c_mat, ystar_mat, _, Ct, Ct1, Ct0, Aft1, Aft0, Amt1, Amt0, gt = ha_out

    return (taut, wagest, phimt, Yst, thetat, phift, ut, Ydt, Aft, Amt,
            Ct, Ct1, Ct0, Aft1, Aft0, Amt1, Amt0,
            a_p_mat, am_p_mat, c_mat, ystar_mat, gt)
end

# Market-clearing residuals along the transition (length T = length(phimt)):
#   H1 (liquid wealth):    phimt.*Rmt.*Mt - Amt
#   H2 (illiquid wealth):  (1-u)*Σ_k gz[k]*phift[t,k] + Agt - Aft
#   H3 (early consumption):(1-u)*Σ_k zgrid[k]*gz[k] * Yst - Ydt
# The weights gz and zgrid come from `mm`.
function mkt_clear_trans(phimt,Rmt,Rft,Mt,Amt,ut,phift,Agt,Aft,Yst,Ydt; parms::parmsmodel,mm::model_outcomes,mm_new::model_outcomes)
    nper = length(phimt)
    employed = 1.0 .- ut[1:nper]
    mean_z = sum(mm.zgrid.*mm.gz)

    H1t = phimt.*Rmt.*Mt .- Amt

    H2t = zeros(nper)
    H2t[1:nper] = employed.*(mm.gz'*phift')' .+ Agt[1:nper] .- Aft[1:nper]

    H3t = zeros(nper)
    H3t[1:nper] = employed.*mean_z.*Yst[1:nper] .- Ydt[1:nper]

    return H1t, H2t, H3t
end

###############################################################################
# Jacobians of blocks of DAG (see appendix B)
###############################################################################

#HA jacobian (using Auclert, et al. (2021)'s fake news algorithm)
# Household-block Jacobians by the fake-news algorithm.
# Inputs i (M = 6 + 2*Nz of them), each perturbed by ±dx at the last date T:
#   1 pyt, 2 Rft, 3 Rmt, 4 taut[:,1], 5 taut[:,2], 6 thetat,
#   6+2k+1 wagest[1,k+1,:], 6+2k+2 wagest[2,k+1,:]   (k = 0,...,Nz-1).
# Outputs j: 1 liquid assets (am_p), 2 illiquid assets (a_p - am_p),
#   3 the ymf/ym aggregate used for early consumption.
# Distribution mm.g, transition matrix mm.T, policies and PI are read from `mm`;
# the terminal Wa comes from `mm_new.Wa`.
# NOTE(review): the algorithm treats the input paths as a steady state of `mm`
# (as in the call from NLPFD via H_jacobian, where mm = mm_new) — confirm for
# other callers.
# Returns (am_p_mat, a_p_mat, c_mat, ystar_mat, Wa_mat, T_mat, dYs0, Dsdx, Et, F,
# J_het); J_het[t,s,i,j] is ∂(output j at t)/∂(input i at s).
function ha_jacobian(Rmt,Rft,pyt,taut,wagest,thetat;dx,mm::model_outcomes,mm_new::model_outcomes,parms::parmsmodel)
    #define perturbation at the last date (index T, i.e. s = T-1 counting from 0)
    T=length(pyt);
    M=6+2*Nz;
    X = zeros(M,T);
    es=zeros(T); es[T]=1.0;

    #define policy responses to shock at s=T-1
    mm_trans = deepcopy(mm_new); #placeholder for household and firm objects
    parms_trans = deepcopy(parms); #placeholder for parameters
    am_p_mat = zeros((M,2,T,Na,2,Nz));
    a_p_mat = zeros((M,2,T,Na,2,Nz));
    c_mat = zeros((M,2,T,Na,2,Nz));
    ystar_mat = zeros((M,2,T+1,Na,2,Nz));
    Wa_mat = zeros((M,2,T+1,Na,2,Nz));
    T_mat=zeros((M,2,T,Na*2*Nz,Na*2*Nz));
    dYs0 = zeros((M,3,T));
    thetavec = zeros((T,2));
    pytvec = zeros((T,2));

    #computes two-sided numerical derivatives (n=1: +dx, n=2: -dx)
    for j=1:M
        for n=1:2
            X[1,:] = copy(pyt);
            X[2,:] = copy(Rft);
            X[3,:] = copy(Rmt);
            X[4,:] = copy(taut[:,1]);
            X[5,:] = copy(taut[:,2]);
            X[6,:] = copy(thetat);
            for k=0:Nz-1
                X[6+(2*k)+1,:] = copy(wagest[1,k+1,:]);
                X[6+(2*k)+2,:] = copy(wagest[2,k+1,:]);
            end
            if n==1
                X[j,:]=X[j,:].+es.*dx; 
            else
                X[j,:]=X[j,:].-es.*dx
            end
            pyt2=copy(X[1,:]);
            Rft2=copy(X[2,:]);
            Rmt2=copy(X[3,:]);
            taut2=deepcopy(taut);
            taut2[:,1]=copy(X[4,:]);
            taut2[:,2]=copy(X[5,:]);
            thetat2=copy(X[6,:]);
            wagest2=deepcopy(wagest)
            for k=0:Nz-1
                wagest2[1,k+1,:] = X[6+(2*k)+1,:]
                wagest2[2,k+1,:] = X[6+(2*k)+2,:]
            end
            thetavec[:,n]=copy(X[6,:])
            pytvec[:,n]=copy(X[1,:])

            #assume policy functions right after shock return to steady state policies
            Wa_mat[j,n,T+1,:,:,:] = copy(mm_new.Wa);

            #solve for policies in time 0 if shock will occur in i periods
            #NOTE(review): mm_trans.PI is not reset here, so Wa_iter sees the PI left by
            #the previous transition-matrix loop; harmless only if Wa_iter ignores mm.PI
            for i=T:-1:1
                mm_trans.Rm = copy(Rmt2[i]);
                mm_trans.py = copy(pyt2[i]);
                mm_trans.Rf = copy(Rft2[i]);
                mm_trans.tau[:,1,:] .= copy(taut2[i,1]); 
                mm_trans.tau[:,2,:] .= copy(taut2[i,2]);
                mm_trans.wages_bar = copy(wagest2[:,:,i]); 
                mm_trans.theta = copy(thetat2[i]);

                Wa_mat[j,n,i,:,:,:], am_p_mat[j,n,i,:,:,:], a_p_mat[j,n,i,:,:,:], c_mat[j,n,i,:,:,:], ystar_mat[j,n,i,:,:,:] = Wa_iter(;py=pyt2[i],Rf=Rft2[i],Wa_p=Wa_mat[j,n,i+1,:,:,:],agrid=mm_trans.agrid,parms=parms_trans,mm=mm_trans)
            end

            #Solve for transition matrices (the matrix T in period i, is the transition matrix between wealth at the beginning of the late stage in period i to the last stage in period i+1)
            for i=1:T
                mm_trans.PI = [(1.0-parms_trans.DELTA) parms_trans.DELTA; λ(thetat2[i];parms=parms_trans) 1.0-λ(thetat2[i];parms=parms_trans)]
                mm_trans.ystar = copy(ystar_mat[j,n,i,:,:,:])
                mm_trans.am_p = copy(am_p_mat[j,n,i,:,:,:])
                mm_trans.a_p = copy(a_p_mat[j,n,i,:,:,:])
                mm_trans.c = copy(c_mat[j,n,i,:,:,:])
                transition_g!(mm_trans;py=pyt2[i],parms=parms_trans);
                T_mat[j,n,i,:,:] = copy(mm_trans.T)
            end
        end

        #Update dYs0: response of each output today to a shock s periods ahead,
        #central difference weighted by the steady-state distribution mm.g
        for s=0:T-1
            #liquid assets
            dy = (am_p_mat[j,1,T-s,:,:,:].-am_p_mat[j,2,T-s,:,:,:])./2.0;
            dYs0[j,1,s+1] = sum(dy.*mm.g);

            #illiquid assets
            dy = ((a_p_mat[j,1,T-s,:,:,:].-am_p_mat[j,1,T-s,:,:,:]).-(a_p_mat[j,2,T-s,:,:,:].-am_p_mat[j,2,T-s,:,:,:]))./2.0;
            dYs0[j,2,s+1] = sum(dy.*mm.g);
        
            #early consumption
            dYs0[j,3,s+1] = 0.0;
            for l=1:Nz
                for k=1:2
                    for kk=1:2
                        ystar_int_minus = LinearInterpolation(mm_trans.agrid, ystar_mat[j,2,T-s,:,kk,l], extrapolation_bc=Line());
                        ystar_int_plus = LinearInterpolation(mm_trans.agrid, ystar_mat[j,1,T-s,:,kk,l], extrapolation_bc=Line());
                        PI_plus = [(1.0-parms_trans.DELTA) parms_trans.DELTA; λ(thetavec[T-s,1];parms=parms_trans) 1.0-λ(thetavec[T-s,1];parms=parms_trans)];
                        PI_minus = [(1.0-parms_trans.DELTA) parms_trans.DELTA; λ(thetavec[T-s,2];parms=parms_trans) 1.0-λ(thetavec[T-s,2];parms=parms_trans)];
                        for n=1:Na
                            dYs0[j,3,s+1] = dYs0[j,3,s+1] + parms_trans.ALPHA*parms_trans.ALPHAmf*((PI_plus[k,kk]*ymf(ystar_int_plus,a_p_mat[j,1,T-s,n,k,l],am_p_mat[j,1,T-s,n,k,l],pytvec[T-s,1])[1].-PI_minus[k,kk]*ymf(ystar_int_minus,a_p_mat[j,2,T-s,n,k,l],am_p_mat[j,2,T-s,n,k,l],pytvec[T-s,2])[1])./2.0).*mm.g[n,k,l]

                            dYs0[j,3,s+1] = dYs0[j,3,s+1] + parms_trans.ALPHA*(1.0-parms_trans.ALPHAmf)*((PI_plus[k,kk]*ym(ystar_int_plus,a_p_mat[j,1,T-s,n,k,l],am_p_mat[j,1,T-s,n,k,l],pytvec[T-s,1])[1].-PI_minus[k,kk]*ym(ystar_int_minus,a_p_mat[j,2,T-s,n,k,l],am_p_mat[j,2,T-s,n,k,l],pytvec[T-s,2])[1])./2.0).*mm.g[n,k,l]
                        end
                    end
                end
            end
        end
    end

    #Update Dsdx: change in next period's distribution from a shock s periods ahead
    Dsdx = zeros((M,T,Na*2*Nz));
    ghatss = reshape(mm.g,(Na*2*Nz,1));
    for j=1:M
        for s=0:T-1
            Dsdx[j,s+1,:] = ((T_mat[j,1,T-s,:,:].-T_mat[j,2,T-s,:,:])./2.0)*ghatss
        end
    end
    
    #Calculate expectation vectors
    #(flat index (k-1)*Na+j+(l-1)*(2*Na) is the column-major position of (j,k,l) in (Na,2,Nz))
    Et = zeros((6,T-1,Na*2*Nz));
    Et[1,1,:] = copy(reshape(mm.am_p,(Na*2*Nz,1))); #expected path of am
    Et[4,1,:] = copy(reshape(mm.a_p,(Na*2*Nz,1))); #expected path of a
    Et[2,1,:] = Et[4,1,:].-Et[1,1,:]; #expected path of af
    Et[5,1,:] = copy(reshape(mm.c,(Na*2*Nz,1))); #expected path of c
    for l=1:Nz
        for k=1:2 #employment state in t=0
            for j=1:Na #wealth in beginning of last stage in t=0
                Et[6,1,(k-1)*Na+j + (l-1)*(2*Na)]=0.0
                Et[3,1,(k-1)*Na+j + (l-1)*(2*Na)]=0.0
                for kk=1:2
                    ystar_int = LinearInterpolation(mm.agrid, mm.ystar[:,kk,l], extrapolation_bc=Line());
                    #parameters taken from this call's `parms` (previously the global `parms_base`),
                    #consistent with the dYs0 computation above
                    Et[6,1,(k-1)*Na+j + (l-1)*(2*Na)] += mm.PI[k,kk]*parms_trans.ALPHA*ystar_int(mm.a_p[j,k,l]) #expected path of ystar
                    Et[3,1,(k-1)*Na+j + (l-1)*(2*Na)] += mm.PI[k,kk]*parms_trans.ALPHA*parms_trans.ALPHAmf*ymf(ystar_int,mm.a_p[j,k,l],mm.am_p[j,k,l],mm.py)[1] #expected path of yd
                    Et[3,1,(k-1)*Na+j + (l-1)*(2*Na)] += mm.PI[k,kk]*parms_trans.ALPHA*(1.0-parms_trans.ALPHAmf)*ym(ystar_int,mm.a_p[j,k,l],mm.am_p[j,k,l],mm.py)[1] #expected path of yd
                end
            end
        end
    end
    for i=2:T-1
        for j=1:6
            Et[j,i,:] = (mm.T)'*Et[j,i-1,:];
        end
    end

    #Construct fake-news matrix
    F = zeros((T,T,M,3));
    for i=1:M
        for j=1:3
            F[1,:,i,j] .=  dYs0[i,j,:]./dx
            for t=2:T
                for s=1:T
                    F[t,s,i,j] = Et[j,t-1,:]'*(Dsdx[i,s,:]./dx)
                end
            end
        end
    end

    #Define Jacobians: J[t,s] = J[t-1,s-1] + F[t,s]
    J_het = zeros((T,T,M,3));
    for i=1:M
        for j=1:3
            J_het[1,:,i,j] = F[1,:,i,j];
            J_het[:,1,i,j] = F[:,1,i,j];
            for t=2:T
                for s=2:T
                    J_het[t,s,i,j] = J_het[t-1,s-1,i,j] + F[t,s,i,j];
                end
            end
        end
    end

    return am_p_mat, a_p_mat, c_mat, ystar_mat, Wa_mat, T_mat, dYs0, Dsdx, Et, F, J_het
end

#Fiscal Jacobian
# Fiscal Jacobian: derivatives of taut[:,1] from fiscal_block (taut[:,2] is an
# identical copy), based on
#   taut[t] = Agt[t]/Rft[t] - Agt[t-1] + phimt[t]*(Mt[t] - Mt[t-1]),
# with Mt[0] = M_1. Slices:
#   J[:,:,1] ∂taut/∂Agt,  J[:,:,2] ∂taut/∂Mt,  J[:,:,3] ∂taut/∂Rft,
#   J[:,:,4] ∂taut/∂phimt.
# The horizon is taken from length(Rft) (previously an implicit global `T`).
# `mm` is accepted for signature compatibility but not used.
function fiscal_jacobian(Agt,Mt,Rft,phimt;mm::model_outcomes,M_1=1.0)
    T = length(Rft)
    J = zeros((T,T,4));

    #response of taut to Ag: 1/Rf on the diagonal, -1 for the lagged stock
    J[:,:,1] = diagm(-1=>-ones(T-1),0=>1.0./Rft[1:T]);

    #response of taut to M: phim on the diagonal, -phim[t] for Mt[t-1]
    J[:,:,2] = diagm(-1=>-phimt[2:T],0=>phimt[1:T]);

    #response of taut to Rf
    J[:,:,3] = diagm(0=> -Agt[1:T]./((Rft[1:T]).^2));

    #response of taut to phim: the change in the money stock, with Mt[0] = M_1
    J[2:end,2:end,4] = diagm(0=>Mt[2:T]-Mt[1:T-1]);
    J[1,1,4] = Mt[1]-M_1;

    return J
end
            
#Wage Jacobian
# Wage Jacobian with respect to pyt, for the index-1 wage row of wage_block:
#   w[i,t+1] = LABOR_SHARE*zgrid[i]*(1 + pyt[t]*Ys - κ(Ys)),  Ys = κ_prime_inv(pyt[t]).
# Assuming κ_prime_inv inverts κ′ (so κ′(Ys) = pyt[t]), the derivative is
# LABOR_SHARE*zgrid[i]*Ys, placed on the first subdiagonal (period t+1 wage,
# period t price). The index-2 row is REPLACE_RATE times this; callers scale.
# The horizon is length(pyt) (previously an implicit global `T`). `zgrid`
# defaults to the global `mm.zgrid`, as before; pass it explicitly to avoid
# relying on that global.
function wage_jacobian(pyt;parms::parmsmodel,md::moments_data,zgrid=mm.zgrid)
    T = length(pyt)
    J=zeros((T,T,Nz));
    for i=1:Nz
        J[:,:,i]=diagm(-1=>md.LABOR_SHARE.*zgrid[i].*κ_prime_inv.(pyt[1:T-1];parms))
    end
    return J
end

#Money price Jacobian
# Money-price Jacobian for the recursion phimt[t] = phimt[t+1]/Rmt[t].
#   J[:,:,1] = ∂phimt/∂Rmt: diagonal -phimt[j]/Rmt[j]; entries above the
#              diagonal propagate backward by dividing by Rmt[i].
#   J[:,:,2] = ∂phimt/∂Mt: only the last column is nonzero, starting from
#              -phimt[T]/Mt[T].
# NOTE(review): the diagonal at T treats phimt[T] as proportional to 1/Rmt[T];
# phim_block divides by Rmt[T] twice — confirm which is intended.
function phim_jacobian(Rmt,phimt,Mt)
    n = length(Rmt)
    J = zeros((n,n,2))

    for col in n:-1:1
        J[col,col,1] = -phimt[col]/Rmt[col]
        for row in col-1:-1:1
            J[row,col,1] = J[row+1,col,1]/Rmt[row]
        end
    end

    J[n,n,2] = -phimt[n]/Mt[n]
    for row in n-1:-1:1
        J[row,n,2] = J[row+1,n,2]/Rmt[row]
    end

    return J
end

#Firm Jacobian
# Firm-block Jacobians, each T×T (rows: output date t, columns: input date s).
# The first slice index is the input (1 = pyt, 2 = Rft); the second is the output:
#   [:,:,·,1]    ∂thetat
#   [:,:,·,1+i]  ∂phift[:,i]  (i = 1..Nz), phift[t,i] = zgrid[i]*K*Rft[t]*f(thetat[t])
#   [:,:,1,end]  ∂Yst/∂pyt = 1/κ''(Ys) on the diagonal (∂Yst/∂Rft stays zero)
# θ_t depends on θ_{t+1} (see firms_block), so ∂θ_t/∂x_s is nonzero for s ≥ t
# and is propagated backward with factor (1-DELTA)*f'(θ_{t+1})/(Rft[t]*f'(θ_t)).
# NOTE(review): the 1e-6 floor on θ in firms_block is not differentiated here,
# and firm_rev is presumably 1 + py*Ys - κ(Ys) — confirm.
# The horizon is length(pyt) (previously an implicit global `T`).
function firm_jacobian(pyt,Rft,thetat;parms::parmsmodel,mm_new::model_outcomes,md::moments_data)
    T = length(pyt)
    J=zeros((T,T,2,2+Nz));
    
    #J_thetat_pyt
    J[:,:,1,1] = diagm(0=>((1.0-md.LABOR_SHARE).*κ_prime_inv.(pyt[1:T];parms))./(Rft[1:T].*parms.K.*f_prime.(thetat[1:T];parms)));
    for t=T-1:-1:1
        for i=t+1:T
            J[t,i,1,1] = (((1.0-parms.DELTA)*f_prime(thetat[t+1];parms))/(Rft[t]*f_prime(thetat[t];parms)))*J[t+1,i,1,1]
        end
    end

    #J_thetat_Rft
    theta_next = vcat(thetat[2:T],mm_new.theta)  #θ_{t+1}, the new steady state after T
    J[:,:,2,1] = diagm(0=>-((((1.0-md.LABOR_SHARE).*firm_rev.(pyt[1:T];parms))./(parms.K*f_prime.(thetat[1:T];parms))) .+ (((1.0-parms.DELTA).*f.(theta_next;parms))./f_prime.(thetat[1:T];parms))).*(1.0./(Rft[1:T].^2)))
    for t=T-1:-1:1
        for i=t+1:T
            J[t,i,2,1] = ((1.0-parms.DELTA)*f_prime(thetat[t+1];parms)/(Rft[t].*f_prime(thetat[t];parms)))*J[t+1,i,2,1]
        end
    end
    
    #J_phift_pyt: through θ only
    for i=1:Nz
        for t=1:T
            J[t,:,1,1+i] = mm_new.zgrid[i]*parms.K.*Rft[t].*f_prime(thetat[t];parms).*J[t,:,1,1]
        end
    end

    #J_phift_Rft: direct effect of Rft[t] on the diagonal plus the effect through θ
    for i=1:Nz
        for t=1:T
            J[t,t,2,1+i] = mm_new.zgrid[i]*parms.K*f(thetat[t];parms) + mm_new.zgrid[i]*parms.K*Rft[t]*f_prime(thetat[t];parms)*J[t,t,2,1]
            if t<=T-1
                J[t,t+1:end,2,1+i] = mm_new.zgrid[i]*parms.K.*Rft[t].*f_prime(thetat[t];parms).*J[t,t+1:end,2,1]
            end
        end
    end

    #J_Ys_pyt
    J[:,:,1,end] = diagm(0=>1.0./κ_prime_prime.(κ_prime_inv.(pyt[1:T];parms);parms))
    
    return J
end

#Unemployment Jacobian
# Unemployment Jacobian ∂ut/∂thetat for the law of motion in unemp_block,
#   u_t = (1 - λ(θ_t))*u_{t-1} + DELTA*(1 - u_{t-1}),  u_0 = 1 - mm.emp:
#   diagonal       -λ'(θ_t)*u_{t-1}
#   below diagonal (1 - λ(θ_t) - DELTA)*∂u_{t-1}/∂θ_i   (i < t)
# The horizon is length(thetat) (previously an implicit global `T`).
function unemp_jacobian(thetat,ut;parms::parmsmodel,mm::model_outcomes)
    T = length(thetat)
    J = zeros((T,T))

    u_lag = vcat(1.0-mm.emp, ut[1:T-1])
    J[:,:] = diagm(0=>-λ_prime.(thetat[1:T];parms).*u_lag)
    for t=2:T
        persist = 1.0-λ(thetat[t];parms)-parms.DELTA
        for i=1:t-1
            J[t,i] = persist*J[t-1,i]
        end
    end

    return J
end
    
#Jacobian system for unknowns U=(Rmt,Rft,pyt) and exogenous shocks Z=(Mt,Agt)
# Jacobians of the residuals H = (H1,H2,H3) of mkt_clear_trans, built by the
# chain rule from the block Jacobians.
#   H_U[:,:,u,h] = ∂H_h/∂U_u  with U columns (1 Rmt, 2 Rft, 3 pyt)
#   H_Z[:,:,z,h] = ∂H_h/∂Z_z  with Z columns (1 Mt, 2 Agt)
# HA inputs in J_het[:,:,i,h]: 1 py, 2 Rf, 3 Rm, 4 taut[:,1], 5 taut[:,2],
# 6 θ, 6+2k+1 / 6+2k+2 the two wage rows of productivity k+1.
# Transfers depend on M both directly (J_fiscal[:,:,2]) and through phim
# (J_fiscal[:,:,4]*J_phim[:,:,2]); unemployment enters H2 and H3 through
# ∂u/∂x = J_unemp*∂θ/∂x.
function H_jacobian(U,Z;T,dx,mm::model_outcomes,mm_new::model_outcomes,parms::parmsmodel,md::moments_data,M_1=1.0)
    H_U = zeros((T,T,3,3));
    H_Z = zeros((T,T,2,3));
    Rmt=U[:,1];
    Rft=U[:,2];
    pyt=U[:,3];
    Mt=Z[:,1];
    Agt=Z[:,2];

    #Jacobians of DAG
    phimt = phim_block(Rmt,Mt;mm_new);
    J_phim = phim_jacobian(Rmt,phimt,Mt);
    J_w = wage_jacobian(pyt;parms,md);
    J_fiscal = fiscal_jacobian(Agt,Mt,Rft,phimt;mm,M_1=M_1);
    Yst, thetat, phift = firms_block(pyt,Rft;parms,mm_new);
    J_firms = firm_jacobian(pyt,Rft,thetat;parms,mm_new,md);
    ut = unemp_block(thetat;parms,mm);
    J_unemp = unemp_jacobian(thetat,ut;parms,mm);
    taut = fiscal_block(Rft,Rmt,phimt,Agt,Mt;mm,M_1=M_1);
    wagest = wage_block(pyt; parms,mm,md);
    J_ha = ha_jacobian(Rmt,Rft,pyt,taut,wagest,thetat;dx,mm,mm_new,parms);
    J_het=J_ha[end];

    #dH1_dpy
    H_U[:,:,3,1] = -J_het[:,:,1,1]-J_het[:,:,6,1]*J_firms[:,:,1,1];
    for k=0:Nz-1
        H_U[:,:,3,1] += -J_het[:,:,6+(2*k)+1,1]*J_w[:,:,k+1]-J_het[:,:,6+(2*k)+2,1]*J_w[:,:,k+1]*md.REPLACE_RATE
    end

    #dH1_dRf
    H_U[:,:,2,1] = -J_het[:,:,2,1]-J_het[:,:,4,1]*J_fiscal[:,:,3]-J_het[:,:,5,1]*J_fiscal[:,:,3]-J_het[:,:,6,1]*J_firms[:,:,2,1];
    
    #dH1_dRm
    H_U[:,:,1,1] = repeat((Mt[1:T].*Rmt)',T)'.*J_phim[:,:,1] + diagm(0=>Mt.*phimt) - J_het[:,:,3,1]-J_het[:,:,4,1]*J_fiscal[:,:,4]*J_phim[:,:,1]-J_het[:,:,5,1]*J_fiscal[:,:,4]*J_phim[:,:,1];

    #dH2_dpy
    H_U[:,:,3,2] += - J_het[:,:,1,2] - J_het[:,:,6,2]*J_firms[:,:,1,1];
    for k=1:Nz
        H_U[:,:,3,2] += -repeat(phift[1:T,k]',T)'.*J_unemp*J_firms[:,:,1,1].*mm_new.gz[k];
        H_U[:,:,3,2] += repeat((1.0.-ut[1:T])',T)'.*J_firms[:,:,1,1+k].*mm_new.gz[k];
    end
    for k=0:Nz-1
        H_U[:,:,3,2] += - J_het[:,:,6+(2*k)+1,2]*J_w[:,:,k+1] - J_het[:,:,6+(2*k)+2,2]*J_w[:,:,k+1]*md.REPLACE_RATE 
    end

    #dH2_dRf
    H_U[:,:,2,2] += - J_het[:,:,2,2] - J_het[:,:,4,2]*J_fiscal[:,:,3] - J_het[:,:,5,2]*J_fiscal[:,:,3] - J_het[:,:,6,2]*J_firms[:,:,2,1];
    for k=1:Nz
        #weight by productivity column k, as in dH2_dpy (phift is T×Nz)
        H_U[:,:,2,2] += -repeat(phift[1:T,k]',T)'.*J_unemp*J_firms[:,:,2,1].*mm_new.gz[k];
        H_U[:,:,2,2] += repeat((1.0.-ut[1:T])',T)'.*J_firms[:,:,2,1+k].*mm_new.gz[k];
    end
    
    #dH2_dRm
    H_U[:,:,1,2] =  -J_het[:,:,3,2]-J_het[:,:,4,2]*J_fiscal[:,:,4]*J_phim[:,:,1] - J_het[:,:,5,2]*J_fiscal[:,:,4]*J_phim[:,:,1];

    #dH3_dpy
    H_U[:,:,3,3] = -repeat(Yst[1:T]',T)'.*J_unemp*J_firms[:,:,1,1]*sum(mm.zgrid.*mm.gz);
    H_U[:,:,3,3] += repeat((1.0.-ut[1:T])',T)'.*J_firms[:,:,1,end]*sum(mm.zgrid.*mm.gz);
    H_U[:,:,3,3] += - J_het[:,:,1,3] - J_het[:,:,6,3]*J_firms[:,:,1,1];
    for k=0:Nz-1
        H_U[:,:,3,3] += - J_het[:,:,6+(2*k)+1,3]*J_w[:,:,k+1] - J_het[:,:,6+(2*k)+2,3]*J_w[:,:,k+1]*md.REPLACE_RATE 
    end

    #dH3_dRf
    #∂H3/∂u = -E[z]*Ys, hence the leading minus (same as in dH3_dpy); Ys does not depend on Rf
    H_U[:,:,2,3] = -repeat(Yst[1:T]',T)'.*J_unemp*J_firms[:,:,2,1]*sum(mm.zgrid.*mm.gz);
    H_U[:,:,2,3] += - J_het[:,:,2,3] - J_het[:,:,4,3]*J_fiscal[:,:,3] - J_het[:,:,5,3]*J_fiscal[:,:,3] - J_het[:,:,6,3]*J_firms[:,:,2,1];

    #dH3_dRm
    H_U[:,:,1,3] =  - J_het[:,:,3,3] - J_het[:,:,4,3]*J_fiscal[:,:,4]*J_phim[:,:,1] - J_het[:,:,5,3]*J_fiscal[:,:,4]*J_phim[:,:,1];

    #dH1_dM
    #both transfer columns respond to M through phim via ∂taut/∂phim = J_fiscal[:,:,4]
    H_Z[:,:,1,1] = diagm(0=>phimt.*Rmt) + repeat((Rmt.*Mt[1:T])',T)'.*J_phim[:,:,2] - (J_het[:,:,4,1]*J_fiscal[:,:,2] + J_het[:,:,5,1]*J_fiscal[:,:,2] + J_het[:,:,4,1]*J_fiscal[:,:,4]*J_phim[:,:,2] + J_het[:,:,5,1]*J_fiscal[:,:,4]*J_phim[:,:,2]);

    #dH1_dAg
    H_Z[:,:,2,1] = -(J_het[:,:,4,1]*J_fiscal[:,:,1]+J_het[:,:,5,1]*J_fiscal[:,:,1]);

    #dH2_dM
    H_Z[:,:,1,2] = -(J_het[:,:,4,2]*J_fiscal[:,:,2] + J_het[:,:,5,2]*J_fiscal[:,:,2] + J_het[:,:,4,2]*J_fiscal[:,:,4]*J_phim[:,:,2] + J_het[:,:,5,2]*J_fiscal[:,:,4]*J_phim[:,:,2]);

    #dH2_dAg (Agt also enters H2 directly)
    H_Z[:,:,2,2] = diagm(0=>ones(T)) - (J_het[:,:,4,2]*J_fiscal[:,:,1] + J_het[:,:,5,2]*J_fiscal[:,:,1]);

    #dH3_dM
    H_Z[:,:,1,3] = -(J_het[:,:,4,3]*J_fiscal[:,:,2] + J_het[:,:,5,3]*J_fiscal[:,:,2] + J_het[:,:,4,3]*J_fiscal[:,:,4]*J_phim[:,:,2] + J_het[:,:,5,3]*J_fiscal[:,:,4]*J_phim[:,:,2]);

    #dH3_dAg
    H_Z[:,:,2,3] = -(J_het[:,:,4,3]*J_fiscal[:,:,1] + J_het[:,:,5,3]*J_fiscal[:,:,1]);

    return H_U, H_Z
end

#Non-linear perfect-foresight dynamics
# Non-linear perfect-foresight transition.
# Solves H(U; Z) = 0 for the unknown paths U = [Rmt Rft pyt] (T×3), given the
# exogenous paths Z = [Mt Agt] (T×2). It starts from the guess U0 and runs a
# quasi-Newton iteration that reuses a single Jacobian: H_jacobian evaluated at
# the terminal steady state (Ussnew, Zssnew) with mm = mm_new.
# Stops when the max-norm change in U is <= tol or after max_iter iterations.
# A warning is logged if it stops at max_iter without meeting tol.
# All returned paths are those evaluated at the last U passed to update_trans,
# i.e. before the final Newton update.
function NLPFD(Z,U0,Zssnew,Ussnew;T,dx,mm_base::model_outcomes,mm_new::model_outcomes,parms::parmsmodel,md::moments_data,tol,max_iter)
    
    #initialize
    Rmt=zeros(T); Rft=zeros(T); pyt=zeros(T); 
    Mt=zeros(T); Agt=zeros(T); 
    taut=zeros((T,2)); wagest=zeros((2,Nz,T));
    thetat=zeros(T); phimt=zeros(T); Yst=zeros(T); phift=zeros((T,Nz)); ut=zeros(T); 
    Ct=zeros(T); Ct1=zeros(T); Ct0=zeros(T); 
    Aft1=zeros(T); Aft0=zeros(T); 
    Amt1=zeros(T); Amt0=zeros(T);
    Ydt=zeros(T); Aft=zeros(T); Amt=zeros(T); 
    H1t=zeros(T); H2t=zeros(T); H3t=zeros(T);
    a_p_mat=zeros((T,Na,2,Nz)); am_p_mat=zeros((T,Na,2,Nz));c_mat=zeros((T,Na,2,Nz)); ystar_mat=zeros((T,Na,2,Nz));
    gt = zeros((T+1,Na,2,Nz));

    #solve Jacobians at terminal steady state
    HU, HZ = H_jacobian(Ussnew,Zssnew;T,dx,mm=mm_new,mm_new=mm_new,parms=parms,md=md,M_1=1.0);
    
    #stack into a 3T×3T matrix: row block h = residual H_h, column block u = unknown U_u
    HhatU = zeros((3*T,3*T));
    HhatU[1:T,1:T] = HU[:,:,1,1];
    HhatU[1:T,T+1:2*T] = HU[:,:,2,1];
    HhatU[1:T,2*T+1:3*T] = HU[:,:,3,1];
    HhatU[T+1:2*T,1:T] = HU[:,:,1,2];
    HhatU[T+1:2*T,T+1:2*T] = HU[:,:,2,2];
    HhatU[T+1:2*T,2*T+1:3*T] = HU[:,:,3,2];
    HhatU[2*T+1:3*T,1:T] = HU[:,:,1,3];
    HhatU[2*T+1:3*T,T+1:2*T] = HU[:,:,2,3];
    HhatU[2*T+1:3*T,2*T+1:3*T] = HU[:,:,3,3];     
    #the Jacobian is fixed across iterations, so it is inverted once and reused
    invHhatU = inv(HhatU);
    
    #iterate
    errorU=1.0; iter=0;
    while errorU>tol && iter<max_iter
        iter+=1  

        #Update H
        Uhat = reshape(U0,(3*T,1));
        Rmt[:] = Uhat[1:T];
        Rft[:] = Uhat[T+1:2*T];
        pyt[:] = Uhat[2*T+1:3*T];
        Mt[:] = Z[:,1];
        Agt[:] = Z[:,2];
        taut[:,:], wagest[:,:,:], phimt[:], Yst[:], thetat[:], phift[:,:], ut[:], Ydt[:], Aft[:], Amt[:], Ct[:], Ct1[:], Ct0[:], Aft1[:], Aft0[:], Amt1[:], Amt0[:], a_p_mat[:,:,:,:], am_p_mat[:,:,:,:], c_mat[:,:,:,:], ystar_mat[:,:,:,:], gt[:,:,:,:] = update_trans(Rmt,Rft,pyt,Agt,Mt;mm=mm_base,mm_new=mm_new,parms=parms,md=md,M_1=1.0);
        H1t[:], H2t[:], H3t[:] = mkt_clear_trans(phimt,Rmt,Rft,Mt,Amt,ut,phift,Agt,Aft,Yst,Ydt;parms=parms,mm=mm_base,mm_new=mm_new);
        H = [H1t H2t H3t];
        Hhat = reshape(H, (3*T,1));

        #update U 
        Uhat_prime = Uhat - invHhatU*Hhat;
        U_prime = reshape(Uhat_prime,(T,3));

        #compute error
        errorU = maximum(abs.(Uhat_prime-Uhat))
        U0 = copy(U_prime);
    end

    #make a non-converged path visible instead of returning it silently
    if errorU>tol
        @warn "NLPFD: reached max_iter without meeting tol" iter errorU tol
    end

    return Rmt, Rft, pyt, Mt, Agt, taut, wagest, phimt, Yst, thetat, phift, ut, Ydt, Aft, Amt, Ct, Ct1, Ct0, Aft1, Aft0, Amt1, Amt0, a_p_mat, am_p_mat, c_mat, ystar_mat, gt, H1t, H2t, H3t
end

#END